import torch as pt


def get_device(gpu_id=0):
    """Choose the torch device to run on and report the choice on stdout.

    Args:
        gpu_id: Index of the CUDA device to target when CUDA is available.
            It is not used when falling back to CPU.

    Returns:
        ``torch.device('cuda:<gpu_id>')`` if ``torch.cuda.is_available()``
        is true, otherwise ``torch.device('cpu')``.
    """
    if pt.cuda.is_available():
        name = f'cuda:{gpu_id}'
    else:
        name = 'cpu'
    # Announce the selected device string before building the device object,
    # so the message appears even if construction rejects the string.
    print('device', name)
    return pt.device(name)
